{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "8e9b80fc-12cf-41a9-a0de-354f678b412b",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Sparse mixture of experts language model from scratch inspired by (and largely based on) Andrej Karpathy's makemore (https://github.com/karpathy/makemore) :)\n",
    "\n",
    "This is a from-scratch implementation of a sparse mixture of experts language model. It is inspired by and largely based on Andrej Karpathy's project 'makemore' and borrows most of the reusable components from that implementation. Like makemore, makeMoE is an autoregressive character-level language model, but it uses the sparse mixture of experts architecture instead.\n",
    "\n",
    "As with makemore, PyTorch is the only requirement (so I hope the from-scratch claim is justified).\n",
    "\n",
    "Significant changes from the makemore architecture:\n",
    "\n",
    "- Sparse mixture of experts instead of the single feed-forward neural net.\n",
    "- Top-k gating and noisy top-k gating implementations.\n",
    "- Initialization: Kaiming He initialization is used here, but the point of this notebook is to be hackable, so you can swap in Xavier Glorot etc. and take it for a spin.\n",
    "\n",
    "Unchanged from makemore:\n",
    "- Dataset, preprocessing (tokenization), and the language modeling task Andrej originally chose: generating Shakespeare-like text\n",
    "- Causal self-attention implementation\n",
    "- Training loop\n",
    "- Inference logic\n",
    "\n",
    "Publications heavily referenced for this implementation: \n",
    "- Mixtral of experts: https://arxiv.org/pdf/2401.04088.pdf\n",
    "- Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer: https://arxiv.org/pdf/1701.06538.pdf\n",
    "\n",
    "\n",
    "** While makeMoE_from_Scratch.ipynb gives you the end-to-end code and helps you develop intuition, this notebook focuses solely on training the model. I've therefore omitted the standalone example implementations of the individual components so you can see the final code (still very much from scratch).\n",
    "\n",
    "The code was developed entirely on Databricks using a single A100 GPU. If you're running this on Databricks, you can scale it to a larger GPU cluster on the cloud provider of your choice.\n",
    "\n",
    "I chose to use MLflow (which comes pre-installed on Databricks and is easy to pip install elsewhere) because I find it helpful for tracking and logging all the necessary metrics. This is entirely optional.\n",
    "\n",
    "Please note that the implementation emphasizes readability and hackability over performance, so there are many ways in which you could improve it. Please try, and let me know!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2f4a58a8-bd4c-40de-a4a9-95457842db0b",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "![mixture of experts overview](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/moe.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "35b3daa3-3b3b-47af-b3e7-be95878f9e06",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Using MLflow is entirely optional. I like to use it to track and log everything. If you're using Databricks, it comes pre-installed.\n",
    "%pip install mlflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5e1a3e38-8717-42ec-9bbc-71d3712c1c68",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Import the necessary packages and set the seed for reproducibility. For this notebook, PyTorch is all you need\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "torch.manual_seed(42)\n",
    "#Optional\n",
    "import mlflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "45143d84-28c7-463d-9fb5-e21122842600",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
    "!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "dde3273f-0519-4108-ba84-dfd99e020722",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Causal Scaled Dot Product Self Attention \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "e435d0cf-1383-446a-9026-cd80b4266019",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "![scaled dot product self attention](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/self_attention.png)"
   ]
  },
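  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the computation in the diagram above, using toy sizes and random tensors (illustrative assumptions, not values from this notebook). The query-key scores are scaled by 1/sqrt(head size), future positions are masked with `-inf`, and a softmax turns each row into attention weights that sum to 1:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "B, T, hs = 1, 4, 8         # batch, sequence length, head size (toy values)\n",
    "q = torch.randn(B, T, hs)  # queries\n",
    "k = torch.randn(B, T, hs)  # keys\n",
    "v = torch.randn(B, T, hs)  # values\n",
    "\n",
    "# Scale by 1/sqrt(head size) so the scores keep roughly unit variance\n",
    "scores = q @ k.transpose(-2, -1) * hs**-0.5  # (B, T, T)\n",
    "# Causal mask: position t may only attend to positions <= t\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "scores = scores.masked_fill(tril == 0, float('-inf'))\n",
    "wei = F.softmax(scores, dim=-1)  # each row sums to 1\n",
    "out = wei @ v                    # (B, T, hs)\n",
    "\n",
    "print(wei[0])  # entries above the diagonal are exactly 0\n",
    "```\n",
    "\n",
    "Because `exp(-inf)` is 0, the masked (future) positions receive exactly zero weight, so each token's output depends only on itself and earlier tokens. This mirrors the masking and softmax steps in the `Head` class further down.\n"
   ]
  },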
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "a7764385-26e9-4d75-9aa7-ce011023e24e",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Top-k Gating"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "d3fca4df-4c47-4e9a-98cd-08cf8ccf7726",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "![top k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/topk.png)"
   ]
  },
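  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of top-k gating as drawn above, with toy sizes (3 tokens, embedding size 8, 4 experts, k = 2) chosen purely for illustration. It follows the same topk, scatter, and softmax steps as the router in the final code, minus the noise:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_experts, top_k = 4, 2\n",
    "x = torch.randn(3, 8)                     # 3 tokens, embedding size 8 (toy values)\n",
    "router = torch.nn.Linear(8, num_experts)  # one logit per expert\n",
    "\n",
    "logits = router(x)                                  # (3, num_experts)\n",
    "top_k_logits, indices = logits.topk(top_k, dim=-1)  # the k highest-scoring experts per token\n",
    "# Every expert outside the top k gets -inf, so the softmax gives it weight 0\n",
    "sparse_logits = torch.full_like(logits, float('-inf')).scatter(-1, indices, top_k_logits)\n",
    "gates = F.softmax(sparse_logits, dim=-1)            # (3, num_experts)\n",
    "\n",
    "print(indices)\n",
    "print(gates)\n",
    "```\n",
    "\n",
    "Setting every non-selected logit to `-inf` before the softmax means the k selected experts share all of the gate weight and the others get exactly 0, so only k experts need to run for each token.\n"
   ]
  },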
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "9b1f67bd-2930-4fad-9e9f-0ce2f278c4c4",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Noisy Top-k Gating"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "e05b3306-b89f-4ebc-901b-f16398a925c2",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "![noisy top-k gating](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/noisytopkgating.png)"
   ]
  },
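  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of noisy top-k gating, again with toy sizes as illustrative assumptions. A second linear layer predicts a per-expert noise scale, `softplus` keeps that scale positive, and the scaled Gaussian noise is added to the router logits before the top-k selection. This mirrors `NoisyTopkRouter` in the final code:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_experts, top_k = 4, 2\n",
    "x = torch.randn(3, 8)                          # 3 tokens, embedding size 8 (toy values)\n",
    "route = torch.nn.Linear(8, num_experts)        # clean router logits\n",
    "noise_scale = torch.nn.Linear(8, num_experts)  # learned per-expert noise scale (before softplus)\n",
    "\n",
    "logits = route(x)\n",
    "# softplus keeps the noise standard deviation positive\n",
    "noise = torch.randn_like(logits) * F.softplus(noise_scale(x))\n",
    "noisy_logits = logits + noise\n",
    "\n",
    "top_k_logits, indices = noisy_logits.topk(top_k, dim=-1)\n",
    "sparse_logits = torch.full_like(noisy_logits, float('-inf')).scatter(-1, indices, top_k_logits)\n",
    "gates = F.softmax(sparse_logits, dim=-1)       # exactly top_k nonzero weights per token\n",
    "print(gates)\n",
    "```\n",
    "\n",
    "Because fresh noise is drawn on every forward pass, a token whose logits are close for two experts can be routed to either one, which is intended to spread tokens across experts instead of always favoring the same few. Note that the router in this notebook adds the noise in both training and evaluation mode.\n"
   ]
  },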
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "d99b5dce-301e-4380-8263-b5cfb4136ab2",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "![sparse MoE](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/sparseMoEfinal.png)"
   ]
  },
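  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how the gated expert outputs are combined, under toy assumptions: tiny sizes, plain `nn.Linear` layers standing in for the MLP experts, and no router noise so the result is deterministic. The loop mirrors `SparseMoE.forward` in the final code, and the last lines check it against a dense reference that runs every expert on every token:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n_embed, num_experts, top_k, n_tokens = 8, 4, 2, 5  # toy sizes\n",
    "x = torch.randn(n_tokens, n_embed)\n",
    "experts = nn.ModuleList([nn.Linear(n_embed, n_embed) for _ in range(num_experts)])  # stand-in experts\n",
    "router = nn.Linear(n_embed, num_experts)\n",
    "\n",
    "# Top-k gating without noise, to keep the check deterministic\n",
    "logits = router(x)\n",
    "top_k_logits, indices = logits.topk(top_k, dim=-1)\n",
    "gates = F.softmax(torch.full_like(logits, float('-inf')).scatter(-1, indices, top_k_logits), dim=-1)\n",
    "\n",
    "# Sparse combination: each expert only runs on the tokens that selected it\n",
    "out = torch.zeros_like(x)\n",
    "for i, expert in enumerate(experts):\n",
    "    mask = (indices == i).any(dim=-1)  # tokens whose top-k includes expert i\n",
    "    if mask.any():\n",
    "        out[mask] += gates[mask, i].unsqueeze(1) * expert(x[mask])\n",
    "\n",
    "# Dense reference: every expert on every token, weighted by the gates.\n",
    "# Non-selected experts have gate 0, so this should match the sparse result.\n",
    "dense = sum(gates[:, i:i + 1] * experts[i](x) for i in range(num_experts))\n",
    "print(torch.allclose(out, dense, atol=1e-5))  # True\n",
    "```\n",
    "\n",
    "Since non-selected experts have a gate of exactly 0, the sparse loop and the dense reference agree, but the sparse version only evaluates each expert on the tokens routed to it. With this notebook's settings (`num_experts = 8`, `top_k = 2`), each token passes through 2 of the 8 experts, i.e. a quarter of the expert parameters in each layer.\n"
   ]
  },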
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "da5f3be4-f155-4d6c-bcbd-2f88a088261f",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "### Final code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "0eaf71cd-c77e-40c7-b5be-e364e91685cf",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# First, define the hyperparameters and boilerplate code. The imports and data preparation code are repeated for convenience\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "from torch.nn import init\n",
    "\n",
    "# hyperparameters\n",
    "batch_size = 16 # how many independent sequences will we process in parallel?\n",
    "block_size = 32 # what is the maximum context length for predictions?\n",
    "max_iters = 5000\n",
    "eval_interval = 100\n",
    "learning_rate = 1e-3\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "eval_iters = 400\n",
    "head_size = 16\n",
    "n_embed = 128\n",
    "n_head = 8\n",
    "n_layer = 8\n",
    "dropout = 0.1\n",
    "num_experts = 8\n",
    "top_k = 2\n",
    "# ------------\n",
    "\n",
    "torch.manual_seed(42)\n",
    "\n",
    "with open('input.txt', 'r', encoding='utf-8') as f:\n",
    "    text = f.read()\n",
    "\n",
    "# here are all the unique characters that occur in this text\n",
    "chars = sorted(list(set(text)))\n",
    "vocab_size = len(chars)\n",
    "# create a mapping from characters to integers\n",
    "stoi = { ch:i for i,ch in enumerate(chars) }\n",
    "itos = { i:ch for i,ch in enumerate(chars) }\n",
    "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
    "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
    "\n",
    "# Train and test splits\n",
    "data = torch.tensor(encode(text), dtype=torch.long)\n",
    "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
    "train_data = data[:n]\n",
    "val_data = data[n:]\n",
    "\n",
    "# data loading\n",
    "def get_batch(split):\n",
    "    # generate a small batch of data of inputs x and targets y\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
    "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
    "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    return x, y\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_loss():\n",
    "    out = {}\n",
    "    model.eval()\n",
    "    for split in ['train', 'val']:\n",
    "        losses = torch.zeros(eval_iters)\n",
    "        for k in range(eval_iters):\n",
    "            X, Y = get_batch(split)\n",
    "            logits, loss = model(X, Y)\n",
    "            losses[k] = loss.item()\n",
    "        out[split] = losses.mean()\n",
    "    model.train()\n",
    "    return out"
   ]
  },
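  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `encode`, `decode`, and `get_batch` produce, here is a minimal sketch on a toy string (an assumption, standing in for `input.txt`) with a fixed start index instead of `torch.randint`. The target window is the input window shifted one character to the right, so every position is trained to predict the next character:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "text = 'hello world'  # toy corpus standing in for input.txt\n",
    "chars = sorted(set(text))\n",
    "stoi = {ch: i for i, ch in enumerate(chars)}\n",
    "itos = {i: ch for ch, i in stoi.items()}\n",
    "encode = lambda s: [stoi[c] for c in s]\n",
    "decode = lambda l: ''.join(itos[i] for i in l)\n",
    "\n",
    "data = torch.tensor(encode(text), dtype=torch.long)\n",
    "block_size = 4\n",
    "i = 2  # fixed start index (get_batch samples these with torch.randint)\n",
    "x = data[i:i + block_size]          # input context\n",
    "y = data[i + 1:i + block_size + 1]  # targets: the same window shifted by one\n",
    "print(repr(decode(x.tolist())), '->', repr(decode(y.tolist())))  # 'llo ' -> 'lo w'\n",
    "```\n",
    "\n",
    "`get_batch` does exactly this for `batch_size` random start indices at once and stacks the windows into `(batch_size, block_size)` tensors.\n"
   ]
  },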
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "ee1180f7-5004-4425-87fe-9a81a17b9024",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "class Head(nn.Module):\n",
    "    \"\"\" one head of self-attention \"\"\"\n",
    "\n",
    "    def __init__(self, head_size):\n",
    "        super().__init__()\n",
    "        self.key = nn.Linear(n_embed, head_size, bias=False)\n",
    "        self.query = nn.Linear(n_embed, head_size, bias=False)\n",
    "        self.value = nn.Linear(n_embed, head_size, bias=False)\n",
    "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B,T,C = x.shape\n",
    "        k = self.key(x)   # (B,T,hs)\n",
    "        q = self.query(x) # (B,T,hs)\n",
    "        # compute attention scores (\"affinities\"), scaled by 1/sqrt(head_size) rather than 1/sqrt(n_embed)\n",
    "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
    "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
    "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
    "        wei = self.dropout(wei)\n",
    "        # perform the weighted aggregation of the values\n",
    "        v = self.value(x) # (B,T,hs)\n",
    "        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
    "        return out\n",
    "    \n",
    "#Multi-Headed Self Attention\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
    "\n",
    "    def __init__(self, num_heads, head_size):\n",
    "        super().__init__()\n",
    "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
    "        self.proj = nn.Linear(head_size * num_heads, n_embed) # concatenated head outputs -> n_embed\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
    "        out = self.dropout(self.proj(out))\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "03611a92-aaa2-4e0d-9755-cba56f96c794",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "#Expert module\n",
    "class Expert(nn.Module):\n",
    "    \"\"\" Each Expert is a simple MLP: a linear layer, a ReLU non-linearity, a second linear layer, and dropout \"\"\"\n",
    "\n",
    "    def __init__(self, n_embed):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(n_embed, 4 * n_embed),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(4 * n_embed, n_embed),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "    \n",
    "#noisy top-k gating\n",
    "class NoisyTopkRouter(nn.Module):\n",
    "    def __init__(self, n_embed, num_experts, top_k):\n",
    "        super(NoisyTopkRouter, self).__init__()\n",
    "        self.top_k = top_k\n",
    "        #layer for router logits\n",
    "        self.topkroute_linear = nn.Linear(n_embed, num_experts)\n",
    "        self.noise_linear =nn.Linear(n_embed, num_experts)\n",
    "\n",
    "    \n",
    "    def forward(self, mh_output):\n",
    "        # mh_output is the layer-normed residual stream coming out of the multi-head self-attention sub-block\n",
    "        logits = self.topkroute_linear(mh_output)\n",
    "\n",
    "        #Noise logits\n",
    "        noise_logits = self.noise_linear(mh_output)\n",
    "\n",
    "        #Adding scaled unit gaussian noise to the logits\n",
    "        noise = torch.randn_like(logits)*F.softplus(noise_logits)\n",
    "        noisy_logits = logits + noise\n",
    "\n",
    "        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)\n",
    "        zeros = torch.full_like(noisy_logits, float('-inf'))\n",
    "        sparse_logits = zeros.scatter(-1, indices, top_k_logits)\n",
    "        router_output = F.softmax(sparse_logits, dim=-1)\n",
    "        return router_output, indices\n",
    "    \n",
    "#Now create the sparse mixture of experts module\n",
    "class SparseMoE(nn.Module):\n",
    "    def __init__(self, n_embed, num_experts, top_k):\n",
    "        super(SparseMoE, self).__init__()\n",
    "        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)\n",
    "        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n",
    "        self.top_k = top_k\n",
    "\n",
    "    def forward(self, x):\n",
    "        gating_output, indices = self.router(x)\n",
    "        final_output = torch.zeros_like(x)\n",
    "\n",
    "        # Reshape inputs for batch processing\n",
    "        flat_x = x.view(-1, x.size(-1))\n",
    "        flat_gating_output = gating_output.view(-1, gating_output.size(-1))\n",
    "\n",
    "        # Loop over the experts; each one processes only the tokens routed to it\n",
    "        for i, expert in enumerate(self.experts):\n",
    "            # Create a mask for the inputs where the current expert is in top-k\n",
    "            expert_mask = (indices == i).any(dim=-1)\n",
    "            flat_mask = expert_mask.view(-1)\n",
    "\n",
    "            if flat_mask.any():\n",
    "                expert_input = flat_x[flat_mask]\n",
    "                expert_output = expert(expert_input)\n",
    "\n",
    "                # Extract and apply gating scores\n",
    "                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n",
    "                weighted_output = expert_output * gating_scores\n",
    "\n",
    "                # Update final output additively by indexing and adding\n",
    "                final_output[expert_mask] += weighted_output.squeeze(1)\n",
    "\n",
    "        return final_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "bfdff2bb-092f-41c8-9a33-c84e6f8d6633",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# First, create a self-attention + mixture of experts block that can be repeated several times\n",
    "#Copy pasting key architecture variables for clarity\n",
    "\n",
    "class Block(nn.Module):\n",
    "    \"\"\" Mixture of Experts Transformer block: communication followed by computation (multi-head self attention + SparseMoE) \"\"\"\n",
    "\n",
    "    def __init__(self, n_embed, n_head, num_experts, top_k):\n",
    "        # n_embed: embedding dimension, n_head: the number of heads we'd like\n",
    "        super().__init__()\n",
    "        head_size = n_embed // n_head\n",
    "        self.sa = MultiHeadAttention(n_head, head_size)\n",
    "        self.smoe = SparseMoE(n_embed, num_experts, top_k)\n",
    "        self.ln1 = nn.LayerNorm(n_embed)\n",
    "        self.ln2 = nn.LayerNorm(n_embed)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.sa(self.ln1(x))\n",
    "        x = x + self.smoe(self.ln2(x))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2d32a276-d0cc-4808-90d7-62441771af44",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Finally, put it all together to create a sparse mixture of experts language model\n",
    "class SparseMoELanguageModel(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # token and position embedding tables\n",
    "        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)\n",
    "        self.position_embedding_table = nn.Embedding(block_size, n_embed)\n",
    "        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])\n",
    "        self.ln_f = nn.LayerNorm(n_embed) # final layer norm\n",
    "        self.lm_head = nn.Linear(n_embed, vocab_size)\n",
    "\n",
    "    def forward(self, idx, targets=None):\n",
    "        B, T = idx.shape\n",
    "\n",
    "        # idx and targets are both (B,T) tensor of integers\n",
    "        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
    "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
    "        x = tok_emb + pos_emb # (B,T,C)\n",
    "        x = self.blocks(x) # (B,T,C)\n",
    "        x = self.ln_f(x) # (B,T,C)\n",
    "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
    "\n",
    "        if targets is None:\n",
    "            loss = None\n",
    "        else:\n",
    "            B, T, C = logits.shape\n",
    "            logits = logits.view(B*T, C)\n",
    "            targets = targets.view(B*T)\n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        # idx is (B, T) array of indices in the current context\n",
    "        for _ in range(max_new_tokens):\n",
    "            # crop idx to the last block_size tokens\n",
    "            idx_cond = idx[:, -block_size:]\n",
    "            # get the predictions\n",
    "            logits, loss = self(idx_cond)\n",
    "            # focus only on the last time step\n",
    "            logits = logits[:, -1, :] # becomes (B, C)\n",
    "            # apply softmax to get probabilities\n",
    "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
    "            # sample from the distribution\n",
    "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            # append sampled index to the running sequence\n",
    "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
    "        return idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "622ba7ce-3f20-4820-8982-93f3d3b7be09",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "Kaiming He initialization is used here because of the ReLU activations in the experts. Feel free to experiment with Glorot (Xavier) initialization, which is more commonly used in transformers. Jeremy Howard's fast.ai Part 2 course has an excellent lecture that implements these from scratch: https://course.fast.ai/Lessons/lesson17.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "a6d3c057-08ee-4c1b-8013-6a88b2eadac5",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "\n",
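    "def kaiming_init_weights(m):\n",
    "    # Kaiming (He) normal init for the weights of every nn.Linear; biases keep PyTorch's default init\n",
    "    if isinstance(m, nn.Linear):\n",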
    "        init.kaiming_normal_(m.weight)"
   ]
  },
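  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following the suggestion above to try Glorot initialization, here is a minimal sketch of a drop-in alternative to `kaiming_init_weights`. The name `xavier_init_weights` is hypothetical, and zeroing the biases is an extra choice of this sketch (the Kaiming version leaves biases at PyTorch's default). A small toy network is used to check the weight range:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import init\n",
    "\n",
    "def xavier_init_weights(m):\n",
    "    # Hypothetical alternative to kaiming_init_weights: Glorot/Xavier uniform weights\n",
    "    if isinstance(m, nn.Linear):\n",
    "        init.xavier_uniform_(m.weight)\n",
    "        if m.bias is not None:\n",
    "            init.zeros_(m.bias)  # zero biases: an extra choice made in this sketch\n",
    "\n",
    "toy = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))\n",
    "toy.apply(xavier_init_weights)\n",
    "# xavier_uniform_ samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))\n",
    "bound = (6 / (16 + 64)) ** 0.5\n",
    "print(toy[0].weight.abs().max().item() <= bound, bool(torch.all(toy[0].bias == 0)))\n",
    "```\n",
    "\n",
    "To use it, call `model.apply(xavier_init_weights)` in place of `model.apply(kaiming_init_weights)` in the next cell.\n"
   ]
  },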
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5b4d9525-8405-4a51-adda-661aba004e57",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "model = SparseMoELanguageModel()\n",
    "model.apply(kaiming_init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "6adf1d04-e668-4d14-b691-161ea4e4dccf",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "I used MLflow to track and log the training hyperparameters and the metrics I care about. The training loop in the next cell includes this MLflow code; if you prefer to train without MLflow, use the cell after it instead. Run only one of the two, since running both would continue training the same model. That said, I find MLflow very convenient for tracking parameters and metrics, particularly when experimenting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "b8968247-0d7b-4460-b96b-06743b31c55d",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "#Using MLFlow\n",
    "m = model.to(device)\n",
    "# print the number of parameters in the model\n",
    "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
    "\n",
    "# create a PyTorch optimizer\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "#mlflow.set_experiment(\"makeMoE\")\n",
    "with mlflow.start_run():\n",
    "    #If you use mlflow.autolog() this will be automatically logged. I chose to explicitly log here for completeness\n",
    "    params = {\"batch_size\": batch_size , \"block_size\" : block_size, \"max_iters\": max_iters, \"eval_interval\": eval_interval, \"learning_rate\": learning_rate, \"device\": device, \"eval_iters\": eval_iters, \"dropout\" : dropout, \"num_experts\": num_experts, \"top_k\": top_k }\n",
    "    mlflow.log_params(params)\n",
    "    for iter in range(max_iters):\n",
    "\n",
    "        # every once in a while evaluate the loss on train and val sets\n",
    "        if iter % eval_interval == 0 or iter == max_iters - 1:\n",
    "            losses = estimate_loss()\n",
    "            print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
    "            metrics = {\"train_loss\": float(losses['train']), \"val_loss\": float(losses['val'])}\n",
    "            mlflow.log_metrics(metrics, step=iter)\n",
    "\n",
    "\n",
    "        # sample a batch of data\n",
    "        xb, yb = get_batch('train')\n",
    "\n",
    "        # evaluate the loss\n",
    "        logits, loss = model(xb, yb)\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "1ed96085-c292-4624-a2cc-be8aad38df79",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "Logging the train and validation losses gives a good indication of how training is progressing. The plot shows that I probably should have stopped at around 4500 steps, where the validation loss jumps up a bit.\n",
    "\n",
    "![mlflow_dash](https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/images/mlflow_dash.png)"
   ]
  },
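  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than picking the stopping point from the plot after the fact, one option is to keep the weights from the evaluation with the lowest validation loss and stop once it has not improved for a while. Here is a minimal sketch; the `BestCheckpoint` class and its `patience` parameter are hypothetical helpers, not part of the original notebook, and the losses in the usage example are made up for illustration:\n",
    "\n",
    "```python\n",
    "import copy\n",
    "from torch import nn\n",
    "\n",
    "class BestCheckpoint:\n",
    "    # Hypothetical helper: keeps a copy of the weights with the lowest validation loss\n",
    "    def __init__(self, patience=5):\n",
    "        self.patience = patience  # evaluations without improvement before stopping\n",
    "        self.best_loss = float('inf')\n",
    "        self.best_step = None\n",
    "        self.best_state = None\n",
    "        self.bad_evals = 0\n",
    "\n",
    "    def update(self, model, val_loss, step):\n",
    "        # Record one evaluation; returns True once training should stop\n",
    "        if val_loss < self.best_loss:\n",
    "            self.best_loss, self.best_step = val_loss, step\n",
    "            self.best_state = copy.deepcopy(model.state_dict())\n",
    "            self.bad_evals = 0\n",
    "        else:\n",
    "            self.bad_evals += 1\n",
    "        return self.bad_evals >= self.patience\n",
    "\n",
    "# Usage with a toy model and made-up losses (illustration only)\n",
    "toy = nn.Linear(2, 2)\n",
    "ckpt = BestCheckpoint(patience=2)\n",
    "for step, val_loss in [(0, 3.0), (100, 2.5), (200, 2.6), (300, 2.7)]:\n",
    "    if ckpt.update(toy, val_loss, step):\n",
    "        break\n",
    "print(ckpt.best_step, ckpt.best_loss)  # 100 2.5\n",
    "```\n",
    "\n",
    "In either training loop, this would be called right after `losses = estimate_loss()`, e.g. `if ckpt.update(model, float(losses['val']), iter): break`, followed by `model.load_state_dict(ckpt.best_state)` once the loop ends.\n"
   ]
  },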
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "6360e1b7-94c4-4ef1-a850-9bc93f49a083",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "#Not using MLflow\n",
    "m = model.to(device)\n",
    "# print the number of parameters in the model\n",
    "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
    "\n",
    "# create a PyTorch optimizer\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for iter in range(max_iters):\n",
    "\n",
    "    # every once in a while evaluate the loss on train and val sets\n",
    "    if iter % eval_interval == 0 or iter == max_iters - 1:\n",
    "        losses = estimate_loss()\n",
    "        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
    "\n",
    "    # sample a batch of data\n",
    "    xb, yb = get_batch('train')\n",
    "\n",
    "    # evaluate the loss\n",
    "    logits, loss = model(xb, yb)\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "8aa6e4c4-c688-4985-a3b8-e2af1f771e54",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
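    "# generate from the model. Not great. Not too bad either\n",
    "m.eval()  # turn off dropout for sampling\n",
    "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
    "with torch.no_grad():  # no gradients are needed for generation\n",
    "    print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))"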
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 4
   },
   "notebookName": "makeMoE_concise",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
